{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "pO4-CY_TCZZS"
      },
      "source": [
        "# Train a Simple Audio Recognition Model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "BaFfr7DHRmGF"
      },
      "source": [
        "This notebook demonstrates how to train a 20 kB [Simple Audio Recognition](https://www.tensorflow.org/tutorials/sequences/audio_recognition) model to recognize keywords in speech.\n",
        "\n",
        "The model created in this notebook is used in the [micro_speech](https://github.com/tensorflow/tflite-micro/blob/main/tensorflow/lite/micro/examples/micro_speech) example for [TensorFlow Lite for MicroControllers](https://www.tensorflow.org/lite/microcontrollers/overview).\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tflite-micro/blob/main/tensorflow/lite/micro/examples/micro_speech/train/train_micro_speech_model.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/tflite-micro/blob/main/tensorflow/lite/micro/examples/micro_speech/train/train_micro_speech_model.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "XaVtYN4nlCft"
      },
      "source": [
        "**Training is much faster using GPU acceleration.** Before you proceed, ensure you are using a GPU runtime by going to **Runtime -> Change runtime type** and set **Hardware accelerator: GPU**. Training 15,000 iterations will take 1.5 - 2 hours on a GPU runtime.\n",
        "\n",
        "## Configure Defaults\n",
        "\n",
        "**MODIFY** the following constants for your specific use case."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "ludfxbNIaegy"
      },
      "outputs": [],
      "source": [
        "# A comma-delimited list of the words you want to train for.\n",
        "# The options are: yes,no,up,down,left,right,on,off,stop,go\n",
        "# All the other words will be used to train an \"unknown\" label and silent\n",
        "# audio data with no spoken words will be used to train a \"silence\" label.\n",
        "WANTED_WORDS = \"yes,no\"\n",
        "\n",
        "# The number of steps and learning rates can be specified as comma-separated\n",
        "# lists to define the rate at each stage. For example,\n",
        "# TRAINING_STEPS=12000,3000 and LEARNING_RATE=0.001,0.0001\n",
        "# will run 12,000 training loops in total, with a rate of 0.001 for the first\n",
        "# 8,000, and 0.0001 for the final 3,000.\n",
        "TRAINING_STEPS = \"12000,3000\"\n",
        "LEARNING_RATE = \"0.001,0.0001\"\n",
        "\n",
        "# Calculate the total number of steps, which is used to identify the checkpoint\n",
        "# file name.\n",
        "TOTAL_STEPS = str(sum(map(lambda string: int(string), TRAINING_STEPS.split(\",\"))))\n",
        "\n",
        "# Print the configuration to confirm it\n",
        "print(\"Training these words: %s\" % WANTED_WORDS)\n",
        "print(\"Training steps in each stage: %s\" % TRAINING_STEPS)\n",
        "print(\"Learning rate in each stage: %s\" % LEARNING_RATE)\n",
        "print(\"Total number of training steps: %s\" % TOTAL_STEPS)"
      ]
    },
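    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As a rough illustration of the staged schedule described in the comments above, the sketch below pairs each entry of `TRAINING_STEPS` with the matching entry of `LEARNING_RATE`. It assumes the two lists are matched by position; the helper name `stage_schedule` is hypothetical and is not part of the training script.\n",
        "\n",
        "```python\n",
        "def stage_schedule(training_steps, learning_rate):\n",
        "  # Split both comma-separated strings and pair the entries by position.\n",
        "  steps = [int(s) for s in training_steps.split(',')]\n",
        "  rates = [float(r) for r in learning_rate.split(',')]\n",
        "  if len(steps) != len(rates):\n",
        "    raise ValueError('Expected one learning rate per training stage')\n",
        "  schedule = []\n",
        "  first_step = 1\n",
        "  for count, rate in zip(steps, rates):\n",
        "    # Each entry is (first step, last step, learning rate).\n",
        "    schedule.append((first_step, first_step + count - 1, rate))\n",
        "    first_step += count\n",
        "  return schedule\n",
        "\n",
        "for first, last, rate in stage_schedule('12000,3000', '0.001,0.0001'):\n",
        "  print('Steps %d to %d use learning rate %g' % (first, last, rate))\n",
        "```\n",
        "\n",
        "With the defaults, the first stage covers 12,000 steps and the second covers the remaining 3,000, for 15,000 steps in total."
      ]
    },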
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "gCgeOpvY9pAi"
      },
      "source": [
        "**DO NOT MODIFY** the following constants as they include filepaths used in this notebook and data that is shared during training and inference."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Nd1iM1o2ymvA"
      },
      "outputs": [],
      "source": [
        "# Calculate the percentage of 'silence' and 'unknown' training samples required\n",
        "# to ensure that we have equal number of samples for each label.\n",
        "number_of_labels = WANTED_WORDS.count(',') + 1\n",
        "number_of_total_labels = number_of_labels + 2 # for 'silence' and 'unknown' label\n",
        "equal_percentage_of_training_samples = int(100.0/(number_of_total_labels))\n",
        "SILENT_PERCENTAGE = equal_percentage_of_training_samples\n",
        "UNKNOWN_PERCENTAGE = equal_percentage_of_training_samples\n",
        "\n",
        "# Constants which are shared during training and inference\n",
        "PREPROCESS = 'micro'\n",
        "WINDOW_STRIDE = 20\n",
        "MODEL_ARCHITECTURE = 'tiny_conv' # Other options include: single_fc, conv,\n",
        "                      # low_latency_conv, low_latency_svdf, tiny_embedding_conv\n",
        "\n",
        "# Constants used during training only\n",
        "VERBOSITY = 'WARN'\n",
        "EVAL_STEP_INTERVAL = '1000'\n",
        "SAVE_STEP_INTERVAL = '1000'\n",
        "\n",
        "# Constants for training directories and filepaths\n",
        "DATASET_DIR =  'dataset/'\n",
        "LOGS_DIR = 'logs/'\n",
        "TRAIN_DIR = 'train/' # for training checkpoints and other files.\n",
        "\n",
        "# Constants for inference directories and filepaths\n",
        "import os\n",
        "MODELS_DIR = 'models'\n",
        "if not os.path.exists(MODELS_DIR):\n",
        "  os.mkdir(MODELS_DIR)\n",
        "MODEL_TF = os.path.join(MODELS_DIR, 'model.pb')\n",
        "MODEL_TFLITE = os.path.join(MODELS_DIR, 'model.tflite')\n",
        "FLOAT_MODEL_TFLITE = os.path.join(MODELS_DIR, 'float_model.tflite')\n",
        "MODEL_TFLITE_MICRO = os.path.join(MODELS_DIR, 'model.cc')\n",
        "SAVED_MODEL = os.path.join(MODELS_DIR, 'saved_model')\n",
        "\n",
        "QUANT_INPUT_MIN = 0.0\n",
        "QUANT_INPUT_MAX = 26.0\n",
        "QUANT_INPUT_RANGE = QUANT_INPUT_MAX - QUANT_INPUT_MIN"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "6rLYpvtg9P4o"
      },
      "source": [
        "## Setup Environment\n",
        "\n",
        "Install Dependencies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "ed_XpUrU5DvY"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "T9Ty5mR58E4i"
      },
      "source": [
        "**DELETE** any old data from previous runs\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "APGx0fEh7hFF"
      },
      "outputs": [],
      "source": [
        "!rm -rf {DATASET_DIR} {LOGS_DIR} {TRAIN_DIR} {MODELS_DIR}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "GfEUlfFBizio"
      },
      "source": [
        "Clone the TensorFlow Github Repository, which contains the relevant code required to run this tutorial."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "yZArmzT85SLq"
      },
      "outputs": [],
      "source": [
        "!git clone -q --depth 1 https://github.com/tensorflow/tensorflow"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "nS9swHLSi7Bi"
      },
      "source": [
        "Load TensorBoard to visualize the accuracy and loss as training proceeds.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "q4qF1VxP3UE4"
      },
      "outputs": [],
      "source": [
        "%load_ext tensorboard\n",
        "%tensorboard --logdir {LOGS_DIR}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "x1J96Ron-O4R"
      },
      "source": [
        "## Training\n",
        "\n",
        "The following script downloads the dataset and begin training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "VJsEZx6lynbY"
      },
      "outputs": [],
      "source": [
        "!python tensorflow/tensorflow/examples/speech_commands/train.py \\\n",
        "--data_dir={DATASET_DIR} \\\n",
        "--wanted_words={WANTED_WORDS} \\\n",
        "--silence_percentage={SILENT_PERCENTAGE} \\\n",
        "--unknown_percentage={UNKNOWN_PERCENTAGE} \\\n",
        "--preprocess={PREPROCESS} \\\n",
        "--window_stride={WINDOW_STRIDE} \\\n",
        "--model_architecture={MODEL_ARCHITECTURE} \\\n",
        "--how_many_training_steps={TRAINING_STEPS} \\\n",
        "--learning_rate={LEARNING_RATE} \\\n",
        "--train_dir={TRAIN_DIR} \\\n",
        "--summaries_dir={LOGS_DIR} \\\n",
        "--verbosity={VERBOSITY} \\\n",
        "--eval_step_interval={EVAL_STEP_INTERVAL} \\\n",
        "--save_step_interval={SAVE_STEP_INTERVAL}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "UczQKtqLi7OJ"
      },
      "source": [
        "# Skipping the training\n",
        "\n",
        "If you don't want to spend an hour or two training the model from scratch, you can download pretrained checkpoints by uncommenting the lines below (removing the '#'s at the start of each line) and running them."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "RZw3VNlnla-J"
      },
      "outputs": [],
      "source": [
        "#!curl -O \"https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_micro_train_2020_05_10.tgz\"\n",
        "#!tar xzf speech_micro_train_2020_05_10.tgz"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "XQUJLrdS-ftl"
      },
      "source": [
        "## Generate a TensorFlow Model for Inference\n",
        "\n",
        "Combine relevant training results (graph, weights, etc) into a single file for inference. This process is known as freezing a model and the resulting model is known as a frozen model/graph, as it cannot be further re-trained after this process."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "xyc3_eLh9sAg"
      },
      "outputs": [],
      "source": [
        "!rm -rf {SAVED_MODEL}\n",
        "!python tensorflow/tensorflow/examples/speech_commands/freeze.py \\\n",
        "--wanted_words=$WANTED_WORDS \\\n",
        "--window_stride_ms=$WINDOW_STRIDE \\\n",
        "--preprocess=$PREPROCESS \\\n",
        "--model_architecture=$MODEL_ARCHITECTURE \\\n",
        "--start_checkpoint=$TRAIN_DIR$MODEL_ARCHITECTURE'.ckpt-'{TOTAL_STEPS} \\\n",
        "--save_format=saved_model \\\n",
        "--output_file={SAVED_MODEL}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "_DBGDxVI-nKG"
      },
      "source": [
        "## Generate a TensorFlow Lite Model\n",
        "\n",
        "Convert the frozen graph into a TensorFlow Lite model, which is fully quantized for use with embedded devices.\n",
        "\n",
        "The following cell will also print the model size, which will be under 20 kilobytes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "RIitkqvGWmre"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "# We add this path so we can import the speech processing modules.\n",
        "sys.path.append(\"/content/tensorflow/tensorflow/examples/speech_commands/\")\n",
        "import input_data\n",
        "import models\n",
        "import numpy as np"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "kzqECqMxgBh4"
      },
      "outputs": [],
      "source": [
        "SAMPLE_RATE = 16000\n",
        "CLIP_DURATION_MS = 1000\n",
        "WINDOW_SIZE_MS = 30.0\n",
        "FEATURE_BIN_COUNT = 40\n",
        "BACKGROUND_FREQUENCY = 0.8\n",
        "BACKGROUND_VOLUME_RANGE = 0.1\n",
        "TIME_SHIFT_MS = 100.0\n",
        "\n",
        "DATA_URL = 'https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz'\n",
        "VALIDATION_PERCENTAGE = 10\n",
        "TESTING_PERCENTAGE = 10"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "rNQdAplJV1fz"
      },
      "outputs": [],
      "source": [
        "model_settings = models.prepare_model_settings(\n",
        "    len(input_data.prepare_words_list(WANTED_WORDS.split(','))),\n",
        "    SAMPLE_RATE, CLIP_DURATION_MS, WINDOW_SIZE_MS,\n",
        "    WINDOW_STRIDE, FEATURE_BIN_COUNT, PREPROCESS)\n",
        "audio_processor = input_data.AudioProcessor(\n",
        "    DATA_URL, DATASET_DIR,\n",
        "    SILENT_PERCENTAGE, UNKNOWN_PERCENTAGE,\n",
        "    WANTED_WORDS.split(','), VALIDATION_PERCENTAGE,\n",
        "    TESTING_PERCENTAGE, model_settings, LOGS_DIR)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "lBj_AyCh1cC0"
      },
      "outputs": [],
      "source": [
        "with tf.compat.v1.Session() as sess:\n",
        "  float_converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL)\n",
        "  float_tflite_model = float_converter.convert()\n",
        "  float_tflite_model_size = open(FLOAT_MODEL_TFLITE, \"wb\").write(float_tflite_model)\n",
        "  print(\"Float model is %d bytes\" % float_tflite_model_size)\n",
        "\n",
        "  converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL)\n",
        "  converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
        "  converter.inference_input_type = tf.int8\n",
        "  converter.inference_output_type = tf.int8\n",
        "  def representative_dataset_gen():\n",
        "    for i in range(100):\n",
        "      data, _ = audio_processor.get_data(1, i*1, model_settings,\n",
        "                                         BACKGROUND_FREQUENCY, \n",
        "                                         BACKGROUND_VOLUME_RANGE,\n",
        "                                         TIME_SHIFT_MS,\n",
        "                                         'testing',\n",
        "                                         sess)\n",
        "      flattened_data = np.array(data.flatten(), dtype=np.float32).reshape(1, 1960)\n",
        "      yield [flattened_data]\n",
        "  converter.representative_dataset = representative_dataset_gen\n",
        "  tflite_model = converter.convert()\n",
        "  tflite_model_size = open(MODEL_TFLITE, \"wb\").write(tflite_model)\n",
        "  print(\"Quantized model is %d bytes\" % tflite_model_size)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "EeLiDZTbLkzv"
      },
      "source": [
        "## Testing the TensorFlow Lite model's accuracy\n",
        "\n",
        "Verify that the model we've exported is still accurate, using the TF Lite Python API and our test set."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 110,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "wQsEteKRLryJ"
      },
      "outputs": [],
      "source": [
        "# Helper function to run inference\n",
        "def run_tflite_inference(tflite_model_path, model_type=\"Float\"):\n",
        "  # Load test data\n",
        "  np.random.seed(0) # set random seed for reproducible test results.\n",
        "  with tf.compat.v1.Session() as sess:\n",
        "    test_data, test_labels = audio_processor.get_data(\n",
        "        -1, 0, model_settings, BACKGROUND_FREQUENCY, BACKGROUND_VOLUME_RANGE,\n",
        "        TIME_SHIFT_MS, 'testing', sess)\n",
        "  test_data = np.expand_dims(test_data, axis=1).astype(np.float32)\n",
        "\n",
        "  # Initialize the interpreter\n",
        "  interpreter = tf.lite.Interpreter(tflite_model_path,\n",
        "                                    experimental_op_resolver_type=tf.lite.experimental.OpResolverType.BUILTIN_REF)\n",
        "  interpreter.allocate_tensors()\n",
        "\n",
        "  input_details = interpreter.get_input_details()[0]\n",
        "  output_details = interpreter.get_output_details()[0]\n",
        "\n",
        "  # For quantized models, manually quantize the input data from float to integer\n",
        "  if model_type == \"Quantized\":\n",
        "    input_scale, input_zero_point = input_details[\"quantization\"]\n",
        "    test_data = test_data / input_scale + input_zero_point\n",
        "    test_data = test_data.astype(input_details[\"dtype\"])\n",
        "\n",
        "  correct_predictions = 0\n",
        "  for i in range(len(test_data)):\n",
        "    interpreter.set_tensor(input_details[\"index\"], test_data[i])\n",
        "    interpreter.invoke()\n",
        "    output = interpreter.get_tensor(output_details[\"index\"])[0]\n",
        "    top_prediction = output.argmax()\n",
        "    correct_predictions += (top_prediction == test_labels[i])\n",
        "\n",
        "  print('%s model accuracy is %f%% (Number of test samples=%d)' % (\n",
        "      model_type, (correct_predictions * 100) / len(test_data), len(test_data)))"
      ]
    },
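    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The manual quantization step in the helper above applies the usual affine mapping between real values and integers. The sketch below shows that mapping in plain Python. The scale and zero point used here (26/255 and -128) are assumptions, chosen only so that the range 0.0 to 26.0 spans the int8 range; the real values are read from the interpreter's input details. The function names are hypothetical, and unlike the helper, which truncates when casting, this sketch rounds to the nearest integer and clamps the result.\n",
        "\n",
        "```python\n",
        "def quantize(value, scale, zero_point):\n",
        "  # Affine quantization: map a real value to an integer in the int8 range.\n",
        "  quantized = int(round(value / scale + zero_point))\n",
        "  return max(-128, min(127, quantized))\n",
        "\n",
        "def dequantize(quantized, scale, zero_point):\n",
        "  # Inverse mapping: recover an approximation of the original real value.\n",
        "  return (quantized - zero_point) * scale\n",
        "\n",
        "# Assumed parameters for illustration only.\n",
        "scale, zero_point = 26.0 / 255.0, -128\n",
        "for value in (0.0, 5.0, 26.0):\n",
        "  q = quantize(value, scale, zero_point)\n",
        "  print(value, '->', q, '->', dequantize(q, scale, zero_point))\n",
        "```\n",
        "\n",
        "The round trip is lossy: a dequantized value can differ from the original by up to about one quantization step."
      ]
    },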
    {
      "cell_type": "code",
      "execution_count": 111,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "l-pD52Na6jRa"
      },
      "outputs": [],
      "source": [
        "# Compute float model accuracy\n",
        "run_tflite_inference(FLOAT_MODEL_TFLITE)\n",
        "\n",
        "# Compute quantized model accuracy\n",
        "run_tflite_inference(MODEL_TFLITE, model_type='Quantized')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "dt6Zqbxu-wIi"
      },
      "source": [
        "## Generate a TensorFlow Lite for MicroControllers Model\n",
        "Convert the TensorFlow Lite model into a C source file that can be loaded by TensorFlow Lite for Microcontrollers."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "XohZOTjR8ZyE"
      },
      "outputs": [],
      "source": [
        "# Install xxd if it is not available\n",
        "!apt-get update && apt-get -qq install xxd\n",
        "# Convert to a C source file\n",
        "!xxd -i {MODEL_TFLITE} > {MODEL_TFLITE_MICRO}\n",
        "# Update variable names\n",
        "REPLACE_TEXT = MODEL_TFLITE.replace('/', '_').replace('.', '_')\n",
        "!sed -i 's/'{REPLACE_TEXT}'/g_model/g' {MODEL_TFLITE_MICRO}"
      ]
    },
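    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To make the conversion step above less opaque, here is a minimal pure-Python sketch of the kind of C source that `xxd -i` followed by the rename to `g_model` produces. It is an approximation for illustration, not a replacement for the commands above; the function name `to_c_array_lines` is hypothetical and the exact whitespace of real `xxd` output may differ.\n",
        "\n",
        "```python\n",
        "def to_c_array_lines(data, name='g_model'):\n",
        "  # Emit 12 hex byte literals per row, similar to the layout of xxd -i.\n",
        "  rows = []\n",
        "  for offset in range(0, len(data), 12):\n",
        "    chunk = data[offset:offset + 12]\n",
        "    rows.append('  ' + ', '.join('0x%02x' % byte for byte in chunk))\n",
        "  lines = ['unsigned char %s[] = {' % name]\n",
        "  # Every row except the last ends with a comma.\n",
        "  lines.extend(row + ',' for row in rows[:-1])\n",
        "  lines.extend(rows[-1:])\n",
        "  lines.append('};')\n",
        "  lines.append('unsigned int %s_len = %d;' % (name, len(data)))\n",
        "  return lines\n",
        "\n",
        "# Placeholder bytes; a real run would read the .tflite file in binary mode.\n",
        "for line in to_c_array_lines(b'placeholder model bytes'):\n",
        "  print(line)\n",
        "```\n",
        "\n",
        "The array holds the model bytes verbatim, and the `_len` variable records how many bytes there are."
      ]
    },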
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "2pQnN0i_-0L2"
      },
      "source": [
        "## Deploy to a Microcontroller\n",
        "\n",
        "Follow the instructions in the [micro_speech](https://github.com/tensorflow/tflite-micro/blob/main/tensorflow/lite/micro/examples/micro_speech) README.md for [TensorFlow Lite for MicroControllers](https://www.tensorflow.org/lite/microcontrollers/overview) to deploy this model on a specific microcontroller.\n",
        "\n",
        "**Reference Model:** If you have not modified this notebook, you can follow the instructions as is, to deploy the model. Refer to the [`micro_speech/train/models`](https://github.com/tensorflow/tflite-micro/blob/main/tensorflow/lite/micro/examples/micro_speech/train/models) directory to access the models generated in this notebook.\n",
        "\n",
        "**New Model:** If you have generated a new model to identify different words: (i) Update `kCategoryCount` and `kCategoryLabels` in [`micro_speech/micro_features/micro_model_settings.h`](https://github.com/tensorflow/tflite-micro/blob/main/tensorflow/lite/micro/examples/micro_speech/micro_features/micro_model_settings.h) and (ii) Update the values assigned to the variables defined in [`micro_speech/micro_features/model.cc`](https://github.com/tensorflow/tflite-micro/blob/main/tensorflow/lite/micro/examples/micro_speech/micro_features/model.cc) with values displayed after running the following cell."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "eoYyh0VU8pca"
      },
      "outputs": [],
      "source": [
        "# Print the C source file\n",
        "!cat {MODEL_TFLITE_MICRO}"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "train_micro_speech_model.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3.9.13 ('venv': venv)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.9.13"
    },
    "vscode": {
      "interpreter": {
        "hash": "22cb1d09959a40fdc50ccd77b5464bb60602aea13b58d7f13d7eaffcd0bc7c7d"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
